{
    "cells": [
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "\"\"\"\n",
                "Please run notebook locally (if you have all the dependencies and a GPU). \n",
                "Technically you can run this notebook on Google Colab but you need to set up microphone for Colab.\n",
                " \n",
                "Instructions for setting up Colab are as follows:\n",
                "1. Open a new Python 3 notebook.\n",
                "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n",
                "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n",
                "4. Run this cell to set up dependencies.\n",
                "5. Set up microphone for Colab\n",
                "\n\nNOTE: User is responsible for checking the content of datasets and the applicable licenses and determining if suitable for the intended use.\n",
                "\"\"\"\n",
                "# If you're using Google Colab and not running locally, run this cell.\n",
                "\n",
                "## Install dependencies\n",
                "!pip install wget\n",
                "!apt-get install sox libsndfile1 ffmpeg portaudio19-dev\n",
                "!pip install text-unidecode\n",
                "!pip install pyaudio\n",
                "\n",
                "# ## Install NeMo\n",
                "BRANCH = 'main'\n",
                "!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[asr]\n",
                "\n",
                "## Install TorchAudio\n",
                "!pip install torchaudio>=0.13.0 -f https://download.pytorch.org/whl/torch_stable.html"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Voice Activity Detection (VAD)\n",
                "\n",
                "\n",
                "This notebook demonstrates how to perform\n",
                "1. [offline streaming inference on audio files (offline VAD)](#Offline-streaming-inference);\n",
                "2. [finetuning](#Finetune) and use [posterior](#Posterior);\n",
                "3. [vad postprocessing and threshold tuning](#VAD-postprocessing-and-Tuning-threshold);\n",
                "4. [online streaming inference](#Online-streaming-inference);\n",
                "5. [online streaming inference from a microphone's stream](#Online-streaming-inference-through-microphone).\n",
		"\n",
		"Note the incompatibility of components could lead to failure of running this notebook locally with container, we might deprecate this notebook and provide a better tutorial in soon releases."
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "The notebook requires PyAudio library to get a signal from an audio device.\n",
                "For Ubuntu, please run the following commands to install it:\n",
                "```\n",
                "sudo apt install python3-pyaudio\n",
                "pip install pyaudio\n",
                "```"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "This notebook requires the `torchaudio` library to be installed for MarbleNet. Please follow the instructions available at the [torchaudio installer](https://github.com/NVIDIA/NeMo/blob/main/scripts/installers/install_torchaudio_latest.sh) and [torchaudio Github page](https://github.com/pytorch/audio#installation) to install the appropriate version of torchaudio.\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import numpy as np\n",
                "import pyaudio as pa\n",
                "import os, time\n",
                "import librosa\n",
                "import IPython.display as ipd\n",
                "import matplotlib.pyplot as plt\n",
                "%matplotlib inline\n",
                "\n",
                "import nemo\n",
                "import nemo.collections.asr as nemo_asr"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# sample rate, Hz\n",
                "SAMPLE_RATE = 16000"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Restore the model from NGC"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {
                "scrolled": true
            },
            "outputs": [],
            "source": [
                "vad_model = nemo_asr.models.EncDecClassificationModel.from_pretrained('vad_marblenet')"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Observing the config of the model"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from omegaconf import OmegaConf\n",
                "import copy"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Preserve a copy of the full config\n",
                "cfg = copy.deepcopy(vad_model._cfg)\n",
                "print(OmegaConf.to_yaml(cfg))"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Setup preprocessor with these settings"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "vad_model.preprocessor = vad_model.from_config_dict(cfg.preprocessor)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Set model to inference mode\n",
                "vad_model.eval();"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "vad_model = vad_model.to(vad_model.device)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "We demonstrate two methods for streaming inference:\n",
                "1. [offline streaming inference (script)](#Offline-streaming-inference)\n",
                "2. [online streaming inference (step-by-step)](#Online-streaming-inference)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Offline streaming inference\n",
                "\n",
                "VAD relies on shorter fixed-length segments for prediction. \n",
                "\n",
                "You can find all necessary steps about inference in \n",
                "```python\n",
                "    Script: <NeMo_git_root>/examples/asr/speech_classification/vad_infer.py  \n",
                "    Config: <NeMo_git_root>/examples/asr/conf/vad/vad_inference_postprocessing.yaml\n",
                "```\n",
                "Duration inference, we generate frame-level prediction by two approaches:\n",
                "\n",
                "1. shift the window of length `window_length_in_sec` (e.g. 0.63s) by `shift_length_in_sec` (e.g. 10ms) to generate the frame and use the prediction of the window to represent the label for the frame; Use \n",
                "```python\n",
                " <NeMo_git_root>/examples/asr/speech_classification/vad_infer.py\n",
                "```\n",
                "\n",
                "    This script will automatically split long audio file to avoid CUDA memory issue and performing **streaming** inside `AudioLabelDataset`."
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "### Posterior\n",
                "<img src=\"https://raw.githubusercontent.com/NVIDIA/NeMo/v1.0.2/tutorials/asr/images/vad_post_overlap_diagram.png\" width=\"500\">\n",
                "\n",
                "2. generate predictions with overlapping input segments. Then a smoothing filter is applied to decide the label for a frame spanned by multiple segments. Perform this step alongside with above step with flag **gen_overlap_seq=True** or use\n",
                "```python\n",
                "<NeMo_git_root>/scripts/voice_activity_detection/vad_overlap_posterior.py\n",
                "```\n",
                "if you already have frame level prediction. \n",
                "\n",
                "Have a look at [MarbleNet paper](https://arxiv.org/pdf/2010.13886.pdf) for choices about segment length, smoothing filter, etc. And play with those parameters with your data.\n",
                "\n",
                "You can also find posterior about converting frame level prediction to speech/no-speech segment in start and end times format in `vad_overlap_posterior.py` or use flag **gen_seg_table=True** alongside with `vad_infer.py`"
            ]
        },
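        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "As a rough illustration of the overlap smoothing described above, here is a minimal NumPy sketch. It assumes one speech probability per segment, a shift of exactly one frame between consecutive segments, and a mean or median filter. The function `smooth_overlapping_predictions` and its arguments are hypothetical names chosen for this example; this is not the implementation in `vad_overlap_posterior.py`.\n",
                "\n",
                "```python\n",
                "import numpy as np\n",
                "\n",
                "def smooth_overlapping_predictions(seg_probs, seg_len, method='mean'):\n",
                "    # seg_probs[i] is the speech probability of the segment that starts at\n",
                "    # frame i and spans seg_len frames (assumed shift: one frame per segment).\n",
                "    seg_probs = np.asarray(seg_probs, dtype=np.float64)\n",
                "    n_frames = len(seg_probs) + seg_len - 1\n",
                "    # Row i holds segment i's probability on the frames it covers, NaN elsewhere.\n",
                "    spread = np.full((len(seg_probs), n_frames), np.nan)\n",
                "    for i, p in enumerate(seg_probs):\n",
                "        spread[i, i:i + seg_len] = p\n",
                "    # Reduce over all segments that cover each frame, ignoring the NaN padding.\n",
                "    reducer = np.nanmean if method == 'mean' else np.nanmedian\n",
                "    return reducer(spread, axis=0)\n",
                "\n",
                "frame_probs = smooth_overlapping_predictions([0.0, 1.0, 1.0], seg_len=2)\n",
                "print(frame_probs)\n",
                "```\n",
                "\n",
                "Frames covered by several segments receive a blend of their predictions, which softens abrupt changes at segment boundaries."
            ]
        },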
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "### Finetune\n",
                "You might need to finetune on your data for better performance. For finetuning/transfer learning, please refer to [**Transfer learning** part of ASR tutorial](https://github.com/NVIDIA/NeMo/blob/stable/tutorials/asr/ASR_with_NeMo.ipynb)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## VAD postprocessing and Tuning threshold"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "We can use a single **threshold** (achieved by onset=offset=0.5) to binarize predictions or use typical VAD postprocessing including\n",
                "\n",
                "### Binarization:\n",
                "1. **onset** and **offset** threshold for detecting the beginning and end of a speech;\n",
                "2. padding durations before (**pad_onset**) and after (**pad_offset**) each speech segment.\n",
                "\n",
                "### Filtering:\n",
                "1. threshold for short speech segment deletion (**min_duration_on**);\n",
                "2. threshold for small silence deletion (**min_duration_off**);\n",
                "3. Whether to perform short speech segment deletion first (**filter_speech_first**).\n",
                "\n",
                "\n",
                "Of course you can do threshold tuning on frame level prediction. We also provide a script \n",
                "```python\n",
                "<NeMo_git_root>/scripts/voice_activity_detection/vad_tune_threshold.py\n",
                "```\n",
                "\n",
                "to help you find best thresholds if you have ground truth label file in RTTM format. "
            ]
        },
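        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "To make the roles of these parameters concrete, here is a minimal pure-Python sketch of onset/offset binarization followed by duration filtering. The functions `binarize` and `filter_segments` are hypothetical helpers written for this example, not NeMo's postprocessing code. The sketch assumes short speech segments are removed before short silences are merged, and it does not merge segments that overlap after padding.\n",
                "\n",
                "```python\n",
                "def binarize(probs, frame_sec, onset=0.5, offset=0.5, pad_onset=0.0, pad_offset=0.0):\n",
                "    # Hysteresis: speech starts when prob >= onset and ends when prob < offset.\n",
                "    segments, start = [], None\n",
                "    for i, p in enumerate(probs):\n",
                "        if start is None and p >= onset:\n",
                "            start = i * frame_sec\n",
                "        elif start is not None and p < offset:\n",
                "            segments.append((max(0.0, start - pad_onset), i * frame_sec + pad_offset))\n",
                "            start = None\n",
                "    if start is not None:\n",
                "        # Speech is still active at the end of the signal: close the segment there.\n",
                "        segments.append((max(0.0, start - pad_onset), len(probs) * frame_sec + pad_offset))\n",
                "    return segments\n",
                "\n",
                "def filter_segments(segments, min_duration_on=0.0, min_duration_off=0.0):\n",
                "    # Drop speech segments shorter than min_duration_on ...\n",
                "    kept = [s for s in segments if s[1] - s[0] >= min_duration_on]\n",
                "    # ... then merge neighbours separated by a silence shorter than min_duration_off.\n",
                "    merged = []\n",
                "    for start, end in kept:\n",
                "        if merged and start - merged[-1][1] < min_duration_off:\n",
                "            merged[-1] = (merged[-1][0], end)\n",
                "        else:\n",
                "            merged.append((start, end))\n",
                "    return merged\n",
                "\n",
                "probs = [0.1, 0.8, 0.6, 0.3, 0.9, 0.9, 0.2]\n",
                "segments = binarize(probs, frame_sec=0.5, onset=0.7, offset=0.4)\n",
                "print(segments)\n",
                "print(filter_segments(segments, min_duration_off=1.0))\n",
                "```\n",
                "\n",
                "With `onset` above `offset`, a segment that has started survives brief dips in probability, which is the main reason to use two thresholds instead of one."
            ]
        },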
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Online streaming inference"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Setting up data for Streaming Inference"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from nemo.core.classes import IterableDataset\n",
                "from nemo.core.neural_types import NeuralType, AudioSignal, LengthsType\n",
                "import torch\n",
                "from torch.utils.data import DataLoader"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# simple data layer to pass audio signal\n",
                "class AudioDataLayer(IterableDataset):\n",
                "    @property\n",
                "    def output_types(self):\n",
                "        return {\n",
                "            'audio_signal': NeuralType(('B', 'T'), AudioSignal(freq=self._sample_rate)),\n",
                "            'a_sig_length': NeuralType(tuple('B'), LengthsType()),\n",
                "        }\n",
                "\n",
                "    def __init__(self, sample_rate):\n",
                "        super().__init__()\n",
                "        self._sample_rate = sample_rate\n",
                "        self.output = True\n",
                "        \n",
                "    def __iter__(self):\n",
                "        return self\n",
                "    \n",
                "    def __next__(self):\n",
                "        if not self.output:\n",
                "            raise StopIteration\n",
                "        self.output = False\n",
                "        return torch.as_tensor(self.signal, dtype=torch.float32), \\\n",
                "               torch.as_tensor(self.signal_shape, dtype=torch.int64)\n",
                "        \n",
                "    def set_signal(self, signal):\n",
                "        self.signal = signal.astype(np.float32)/32768.\n",
                "        self.signal_shape = self.signal.size\n",
                "        self.output = True\n",
                "\n",
                "    def __len__(self):\n",
                "        return 1"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "data_layer = AudioDataLayer(sample_rate=cfg.train_ds.sample_rate)\n",
                "data_loader = DataLoader(data_layer, batch_size=1, collate_fn=data_layer.collate_fn)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# inference method for audio signal (single instance)\n",
                "def infer_signal(model, signal):\n",
                "    data_layer.set_signal(signal)\n",
                "    batch = next(iter(data_loader))\n",
                "    audio_signal, audio_signal_len = batch\n",
                "    audio_signal, audio_signal_len = audio_signal.to(vad_model.device), audio_signal_len.to(vad_model.device)\n",
                "    logits = model.forward(input_signal=audio_signal, input_signal_length=audio_signal_len)\n",
                "    return logits"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# class for streaming frame-based VAD\n",
                "# 1) use reset() method to reset FrameVAD's state\n",
                "# 2) call transcribe(frame) to do VAD on\n",
                "#    contiguous signal's frames\n",
                "# To simplify the flow, we use single threshold to binarize predictions.\n",
                "class FrameVAD:\n",
                "    \n",
                "    def __init__(self, model_definition,\n",
                "                 threshold=0.5,\n",
                "                 frame_len=2, frame_overlap=2.5, \n",
                "                 offset=10):\n",
                "        '''\n",
                "        Args:\n",
                "          threshold: If prob of speech is larger than threshold, classify the segment to be speech.\n",
                "          frame_len: frame's duration, seconds\n",
                "          frame_overlap: duration of overlaps before and after current frame, seconds\n",
                "          offset: number of symbols to drop for smooth streaming\n",
                "        '''\n",
                "        self.vocab = list(model_definition['labels'])\n",
                "        self.vocab.append('_')\n",
                "        \n",
                "        self.sr = model_definition['sample_rate']\n",
                "        self.threshold = threshold\n",
                "        self.frame_len = frame_len\n",
                "        self.n_frame_len = int(frame_len * self.sr)\n",
                "        self.frame_overlap = frame_overlap\n",
                "        self.n_frame_overlap = int(frame_overlap * self.sr)\n",
                "        timestep_duration = model_definition['AudioToMFCCPreprocessor']['window_stride']\n",
                "        for block in model_definition['JasperEncoder']['jasper']:\n",
                "            timestep_duration *= block['stride'][0] ** block['repeat']\n",
                "        self.buffer = np.zeros(shape=2*self.n_frame_overlap + self.n_frame_len,\n",
                "                               dtype=np.float32)\n",
                "        self.offset = offset\n",
                "        self.reset()\n",
                "        \n",
                "    def _decode(self, frame, offset=0):\n",
                "        assert len(frame)==self.n_frame_len\n",
                "        self.buffer[:-self.n_frame_len] = self.buffer[self.n_frame_len:]\n",
                "        self.buffer[-self.n_frame_len:] = frame\n",
                "        logits = infer_signal(vad_model, self.buffer).cpu().numpy()[0]\n",
                "        decoded = self._greedy_decoder(\n",
                "            self.threshold,\n",
                "            logits,\n",
                "            self.vocab\n",
                "        )\n",
                "        return decoded  \n",
                "    \n",
                "    \n",
                "    @torch.no_grad()\n",
                "    def transcribe(self, frame=None):\n",
                "        if frame is None:\n",
                "            frame = np.zeros(shape=self.n_frame_len, dtype=np.float32)\n",
                "        if len(frame) < self.n_frame_len:\n",
                "            frame = np.pad(frame, [0, self.n_frame_len - len(frame)], 'constant')\n",
                "        unmerged = self._decode(frame, self.offset)\n",
                "        return unmerged\n",
                "    \n",
                "    def reset(self):\n",
                "        '''\n",
                "        Reset frame_history and decoder's state\n",
                "        '''\n",
                "        self.buffer=np.zeros(shape=self.buffer.shape, dtype=np.float32)\n",
                "        self.prev_char = ''\n",
                "\n",
                "    @staticmethod\n",
                "    def _greedy_decoder(threshold, logits, vocab):\n",
                "        s = []\n",
                "        if logits.shape[0]:\n",
                "            probs = torch.softmax(torch.as_tensor(logits), dim=-1)\n",
                "            probas, _ = torch.max(probs, dim=-1)\n",
                "            probas_s = probs[1].item()\n",
                "            preds = 1 if probas_s >= threshold else 0\n",
                "            s = [preds, str(vocab[preds]), probs[0].item(), probs[1].item(), str(logits)]\n",
                "        return s"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "\n",
                "\n",
                "Streaming inference depends on a few factors, such as the frame length (STEP) and buffer size (WINDOW SIZE). Experiment with a few values to see their effects in the below cells."
            ]
        },
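        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "To see how these two values translate into samples, the sketch below repeats the arithmetic that `FrameVAD` uses to size its sliding buffer. The helper `buffer_layout` and its dictionary keys are hypothetical names introduced for illustration; the formulas follow the `FrameVAD` constructor and the `frame_overlap=(WINDOW_SIZE - FRAME_LEN) / 2` argument used in this notebook.\n",
                "\n",
                "```python\n",
                "def buffer_layout(step, window_size, sample_rate=16000):\n",
                "    # Each call to transcribe() appends `step` seconds of new audio to a\n",
                "    # sliding buffer that spans roughly `window_size` seconds in total.\n",
                "    frame_overlap = (window_size - step) / 2\n",
                "    n_frame_len = int(step * sample_rate)\n",
                "    n_frame_overlap = int(frame_overlap * sample_rate)\n",
                "    return {\n",
                "        'samples_per_step': n_frame_len,\n",
                "        'overlap_samples': n_frame_overlap,\n",
                "        'buffer_samples': 2 * n_frame_overlap + n_frame_len,\n",
                "    }\n",
                "\n",
                "for step, window_size in [(0.01, 0.31), (0.01, 0.15)]:\n",
                "    print(step, window_size, buffer_layout(step, window_size))\n",
                "```\n",
                "\n",
                "A smaller step produces predictions more often, while a larger window gives the model more context for each prediction at the cost of more computation per step."
            ]
        },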
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "STEP_LIST =        [0.01,0.01]\n",
                "WINDOW_SIZE_LIST = [0.31,0.15]"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import wave\n",
                "\n",
                "def offline_inference(wave_file, STEP = 0.025, WINDOW_SIZE = 0.5, threshold=0.5):\n",
                "    \n",
                "    FRAME_LEN = STEP # infer every STEP seconds \n",
                "    CHANNELS = 1 # number of audio channels (expect mono signal)\n",
                "    RATE = 16000 # sample rate, Hz\n",
                "    \n",
                "   \n",
                "    CHUNK_SIZE = int(FRAME_LEN*RATE)\n",
                "    \n",
                "    vad = FrameVAD(model_definition = {\n",
                "                   'sample_rate': SAMPLE_RATE,\n",
                "                   'AudioToMFCCPreprocessor': cfg.preprocessor,\n",
                "                   'JasperEncoder': cfg.encoder,\n",
                "                   'labels': cfg.labels\n",
                "               },\n",
                "               threshold=threshold,\n",
                "               frame_len=FRAME_LEN, frame_overlap = (WINDOW_SIZE-FRAME_LEN)/2,\n",
                "               offset=0)\n",
                "\n",
                "    wf = wave.open(wave_file, 'rb')\n",
                "    p = pa.PyAudio()\n",
                "\n",
                "    empty_counter = 0\n",
                "\n",
                "    preds = []\n",
                "    proba_b = []\n",
                "    proba_s = []\n",
                "    \n",
                "    stream = p.open(format=p.get_format_from_width(wf.getsampwidth()),\n",
                "                    channels=CHANNELS,\n",
                "                    rate=RATE,\n",
                "                    output = True)\n",
                "\n",
                "    data = wf.readframes(CHUNK_SIZE)\n",
                "\n",
                "    while len(data) > 0:\n",
                "\n",
                "        data = wf.readframes(CHUNK_SIZE)\n",
                "        signal = np.frombuffer(data, dtype=np.int16)\n",
                "        result = vad.transcribe(signal)\n",
                "\n",
                "        preds.append(result[0])\n",
                "        proba_b.append(result[2])\n",
                "        proba_s.append(result[3])\n",
                "        \n",
                "        if len(result):\n",
                "            print(result,end='\\n')\n",
                "            empty_counter = 3\n",
                "        elif empty_counter > 0:\n",
                "            empty_counter -= 1\n",
                "            if empty_counter == 0:\n",
                "                print(' ',end='')\n",
                "                \n",
                "    p.terminate()\n",
                "    vad.reset()\n",
                "    \n",
                "    return preds, proba_b, proba_s"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "### Here we show an example of online streaming inference\n",
                "You can use your file or download the provided demo audio file. "
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {
                "scrolled": true
            },
            "outputs": [],
            "source": [
                "demo_wave = 'VAD_demo.wav'\n",
                "if not os.path.exists(demo_wave):\n",
                "    !wget \"https://dldata-public.s3.us-east-2.amazonaws.com/VAD_demo.wav\" "
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "wave_file = demo_wave\n",
                "\n",
                "CHANNELS = 1\n",
                "RATE = 16000\n",
                "audio, sample_rate = librosa.load(wave_file, sr=RATE)\n",
                "dur = librosa.get_duration(y=audio, sr=sample_rate)\n",
                "print(dur)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {
                "scrolled": true
            },
            "outputs": [],
            "source": [
                "ipd.Audio(audio, rate=sample_rate)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {
                "scrolled": true
            },
            "outputs": [],
            "source": [
                "threshold=0.4\n",
                "\n",
                "results = []\n",
                "for STEP, WINDOW_SIZE in zip(STEP_LIST, WINDOW_SIZE_LIST, ):\n",
                "    print(f'====== STEP is {STEP}s, WINDOW_SIZE is {WINDOW_SIZE}s ====== ')\n",
                "    preds, proba_b, proba_s = offline_inference(wave_file, STEP, WINDOW_SIZE, threshold)\n",
                "    results.append([STEP, WINDOW_SIZE, preds, proba_b, proba_s])"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "To simplify the flow, the above prediction is based on single threshold and `threshold=0.4`.\n",
                "\n",
                "You can play with other [threshold](#VAD-postprocessing-and-Tuning-threshold) or use postprocessing and see how they would impact performance. \n",
                "\n",
                "**Note** if you want better performance, [finetune](#Finetune) on your data and use posteriors such as [overlapped prediction](#Posterior). \n",
                "\n",
                "Let's plot the prediction and melspectrogram"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import librosa.display\n",
                "plt.figure(figsize=[20,10])\n",
                "\n",
                "num = len(results)\n",
                "for i in range(num):\n",
                "    len_pred = len(results[i][2]) \n",
                "    FRAME_LEN = results[i][0]\n",
                "    ax1 = plt.subplot(num+1,1,i+1)\n",
                "\n",
                "    ax1.plot(np.arange(audio.size) / sample_rate, audio, 'b')\n",
                "    ax1.set_xlim([-0.01, int(dur)+1]) \n",
                "    ax1.tick_params(axis='y', labelcolor= 'b')\n",
                "    ax1.set_ylabel('Signal')\n",
                "    ax1.set_ylim([-1,  1])\n",
                "\n",
                "    proba_s = results[i][4]\n",
                "    pred = [1 if p > threshold else 0 for p in proba_s]\n",
                "    ax2 = ax1.twinx()\n",
                "    ax2.plot(np.arange(len_pred)/(1/results[i][0]), np.array(pred)  , 'r', label='pred')\n",
                "    ax2.plot(np.arange(len_pred)/(1/results[i][0]), np.array(proba_s) ,  'g--', label='speech prob')\n",
                "    ax2.tick_params(axis='y', labelcolor='r')\n",
                "    legend = ax2.legend(loc='lower right', shadow=True)\n",
                "    ax1.set_ylabel('prediction')\n",
                "\n",
                "    ax2.set_title(f'step {results[i][0]}s, buffer size {results[i][1]}s')\n",
                "    ax2.set_ylabel('Preds and Probas')\n",
                "    \n",
                "    \n",
                "ax = plt.subplot(num+1,1,num+1)\n",
                "S = librosa.feature.melspectrogram(y=audio, sr=sample_rate, n_mels=64, fmax=8000)\n",
                "S_dB = librosa.power_to_db(S, ref=np.max)\n",
                "librosa.display.specshow(S_dB, x_axis='time', y_axis='mel', sr=sample_rate, fmax=8000)\n",
                "ax.set_title('Mel-frequency spectrogram')\n",
                "ax.grid()\n",
                "plt.show()"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Online streaming inference through microphone"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "**Please note the VAD model is not perfect for various microphone input and you might need to finetune on your input and play with different parameters.**"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "STEP = 0.01 \n",
                "WINDOW_SIZE = 0.31\n",
                "CHANNELS = 1 \n",
                "RATE = 16000\n",
                "FRAME_LEN = STEP\n",
                "THRESHOLD = 0.5\n",
                "\n",
                "CHUNK_SIZE = int(STEP * RATE)\n",
                "vad = FrameVAD(model_definition = {\n",
                "                   'sample_rate': SAMPLE_RATE,\n",
                "                   'AudioToMFCCPreprocessor': cfg.preprocessor,\n",
                "                   'JasperEncoder': cfg.encoder,\n",
                "                   'labels': cfg.labels\n",
                "               },\n",
                "               threshold=THRESHOLD,\n",
                "               frame_len=FRAME_LEN, frame_overlap=(WINDOW_SIZE - FRAME_LEN) / 2, \n",
                "               offset=0)\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "vad.reset()\n",
                "\n",
                "p = pa.PyAudio()\n",
                "print('Available audio input devices:')\n",
                "input_devices = []\n",
                "for i in range(p.get_device_count()):\n",
                "    dev = p.get_device_info_by_index(i)\n",
                "    if dev.get('maxInputChannels'):\n",
                "        input_devices.append(i)\n",
                "        print(i, dev.get('name'))\n",
                "\n",
                "if len(input_devices):\n",
                "    dev_idx = -2\n",
                "    while dev_idx not in input_devices:\n",
                "        print('Please type input device ID:')\n",
                "        dev_idx = int(input())\n",
                "\n",
                "    empty_counter = 0\n",
                "\n",
                "    def callback(in_data, frame_count, time_info, status):\n",
                "        global empty_counter\n",
                "        signal = np.frombuffer(in_data, dtype=np.int16)\n",
                "        text = vad.transcribe(signal)\n",
                "        if len(text):\n",
                "            print(text,end='\\n')\n",
                "            empty_counter = vad.offset\n",
                "        elif empty_counter > 0:\n",
                "            empty_counter -= 1\n",
                "            if empty_counter == 0:\n",
                "                print(' ',end='\\n')\n",
                "        return (in_data, pa.paContinue)\n",
                "\n",
                "    stream = p.open(format=pa.paInt16,\n",
                "                    channels=CHANNELS,\n",
                "                    rate=SAMPLE_RATE,\n",
                "                    input=True,\n",
                "                    input_device_index=dev_idx,\n",
                "                    stream_callback=callback,\n",
                "                    frames_per_buffer=CHUNK_SIZE)\n",
                "\n",
                "    print('Listening...')\n",
                "\n",
                "    stream.start_stream()\n",
                "    \n",
                "    # Interrupt kernel and then speak for a few more words to exit the pyaudio loop !\n",
                "    try:\n",
                "        while stream.is_active():\n",
                "            time.sleep(0.1)\n",
                "    finally:        \n",
                "        stream.stop_stream()\n",
                "        stream.close()\n",
                "        p.terminate()\n",
                "\n",
                "        print()\n",
                "        print(\"PyAudio stopped\")\n",
                "    \n",
                "else:\n",
                "    print('ERROR: No audio input device found.')"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "pycharm": {
                    "name": "#%% md\n"
                }
            },
            "source": [
                "## ONNX Deployment\n",
                "You can also export the model to ONNX file and deploy it to TensorRT or MS ONNX Runtime inference engines. If you don't have one installed yet, please run:"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "!pip install --upgrade onnxruntime # for gpu, use onnxruntime-gpu\n",
                "# !mkdir -p ort\n",
                "# %cd ort\n",
                "# !git clone --depth 1 --branch v1.8.0 https://github.com/microsoft/onnxruntime.git .\n",
                "# !./build.sh --skip_tests --config Release --build_shared_lib --parallel --use_cuda --cuda_home /usr/local/cuda --cudnn_home /usr/lib/x86_64-linux-gnu --build_wheel\n",
                "# !pip install ./build/Linux/Release/dist/onnxruntime*.whl\n",
                "# %cd .."
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "Then just replace `infer_signal` implementation with this code:"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {
                "pycharm": {
                    "name": "#%%\n"
                }
            },
            "outputs": [],
            "source": [
                "import onnxruntime\n",
                "vad_model.export('vad.onnx')\n",
                "ort_session = onnxruntime.InferenceSession('vad.onnx', providers=['CPUExecutionProvider'])\n",
                "\n",
                "def to_numpy(tensor):\n",
                "    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()\n",
                "\n",
                "def infer_signal(signal):\n",
                "    data_layer.set_signal(signal)\n",
                "    batch = next(iter(data_loader))\n",
                "    audio_signal, audio_signal_len = batch\n",
                "    audio_signal, audio_signal_len = audio_signal.to(vad_model.device), audio_signal_len.to(vad_model.device)\n",
                "    processed_signal, processed_signal_len = vad_model.preprocessor(\n",
                "        input_signal=audio_signal, length=audio_signal_len,\n",
                "    )\n",
                "    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(processed_signal), }\n",
                "    ologits = ort_session.run(None, ort_inputs)\n",
                "    alogits = np.asarray(ologits)\n",
                "    logits = torch.from_numpy(alogits[0])\n",
                "    return logits"
            ]
        }
    ],
    "metadata": {
        "kernelspec": {
            "display_name": "Python 3",
            "language": "python",
            "name": "python3"
        },
        "language_info": {
            "codemirror_mode": {
                "name": "ipython",
                "version": 3
            },
            "file_extension": ".py",
            "mimetype": "text/x-python",
            "name": "python",
            "nbconvert_exporter": "python",
            "pygments_lexer": "ipython3",
            "version": "3.7.7"
        }
    },
    "nbformat": 4,
    "nbformat_minor": 4
}
